{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch version: 1.0.0\n",
      "Numpy version: 1.15.4\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Report the library versions this notebook is running against.\n",
    "version_report = [\n",
    "    \"PyTorch version: {}\".format(torch.__version__),\n",
    "    \"Numpy version: {}\".format(np.__version__),\n",
    "]\n",
    "print(\"\\n\".join(version_report))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean of features: [5.84333333 3.054      3.75866667 1.19866667]\n",
      "Std of features: [0.82530129 0.43214658 1.75852918 0.76061262]\n"
     ]
    }
   ],
   "source": [
    "# Load Iris Dataset\n",
    "# Directory that holds the three data files. Set the IRIS_DATA_DIR environment\n",
    "# variable to point somewhere else; the default is the original hard-coded location.\n",
    "FILE_PATH = os.environ.get(\"IRIS_DATA_DIR\", \"/home/nobug-ros/dev_ml/PyTorch-for-Iris-Dataset/\")\n",
    "MAIN_FILE_NAME = \"iris_dataset.txt\"\n",
    "TRAIN_FILE_NAME = \"iris_train_dataset.txt\"\n",
    "TEST_FILE_NAME = \"iris_test_dataset.txt\"\n",
    "\n",
    "\n",
    "def load_table(file_name):\n",
    "    \"\"\"Read the comma-separated numeric file `file_name` from FILE_PATH into a NumPy array.\"\"\"\n",
    "    # os.path.join copes with FILE_PATH given with or without a trailing slash.\n",
    "    return np.loadtxt(os.path.join(FILE_PATH, file_name), delimiter=\",\")\n",
    "\n",
    "\n",
    "data = load_table(MAIN_FILE_NAME)\n",
    "\n",
    "# Per-column mean and standard deviation of the first four columns (the features).\n",
    "# NOTE(review): these statistics come from the main file; if that file also contains\n",
    "# the test rows, test data leaks into the normalisation -- confirm and, if so,\n",
    "# compute them from the training file instead.\n",
    "mean_data = np.mean(data[:, :4], axis=0)\n",
    "std_data = np.std(data[:, :4], axis=0)\n",
    "\n",
    "train_data = load_table(TRAIN_FILE_NAME)\n",
    "test_data = load_table(TEST_FILE_NAME)\n",
    "\n",
    "print(\"Mean of features: {}\".format(mean_data))\n",
    "print(\"Std of features: {}\".format(std_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardize (z-score) the feature columns of the train and test data with the\n",
    "# per-feature mean/std computed when the data was loaded. Columns after the\n",
    "# features are left untouched.\n",
    "#\n",
    "# Run this cell exactly once after loading and before the tensor conversion:\n",
    "# it rewrites the arrays in place, so a second run would standardize values\n",
    "# that are already standardized.\n",
    "n_features = mean_data.shape[0]\n",
    "\n",
    "# Broadcasting applies the (n_features,) statistics to every row at once.\n",
    "train_data[:, :n_features] = (train_data[:, :n_features] - mean_data) / std_data\n",
    "test_data[:, :n_features] = (test_data[:, :n_features] - mean_data) / std_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the train and test data into float32 PyTorch tensors.\n",
    "# torch.as_tensor with an explicit dtype replaces the legacy torch.Tensor(...)\n",
    "# constructor: the dtype is stated rather than implied, and re-running the cell\n",
    "# is harmless because an input that is already a float32 tensor is returned as-is.\n",
    "train_data = torch.as_tensor(train_data, dtype=torch.float32)\n",
    "test_data = torch.as_tensor(test_data, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start building the Neural Network using PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select device to compute cpu/gpu\n",
    "# NOTE: no other cell in this notebook references `device`, and neither the model\n",
    "# nor the data tensors are moved with .to(device), so this setting currently has\n",
    "# no effect on where computation runs.\n",
    "device = torch.device('cpu')\n",
    "# For a GPU the device type is 'cuda'; 'gpu' is not a device string PyTorch accepts.\n",
    "# device = torch.device('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes: mini-batch size (unused by the per-sample training loop), number of\n",
    "# input features, hidden-layer width, number of output classes.\n",
    "batch_sz, D_in, H, D_out = 4, 4, 8, 3\n",
    "\n",
    "# Seed PyTorch's RNG so the initial weights are the same on every fresh run.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Use the nn package to define the model: D_in inputs -> one hidden ReLU layer\n",
    "# of width H -> D_out outputs turned into probabilities by a softmax.\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(D_in, H),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(H, D_out),\n",
    "    # Softmax over the last axis. For the single 1-D sample fed by the training\n",
    "    # loop this is the same axis as dim=0; for a 2-D (batch, D_out) input it\n",
    "    # still normalises each sample separately instead of across the batch.\n",
    "    torch.nn.Softmax(dim=-1),\n",
    ")\n",
    "\n",
    "# Mean squared error between the softmax output and the target vector.\n",
    "# Other candidates (disabled in the original cell): BCELoss, L1Loss,\n",
    "# PairwiseDistance, CrossEntropyLoss. CrossEntropyLoss expects raw scores and\n",
    "# class indices, so it would also require dropping the Softmax layer above.\n",
    "loss_fn = torch.nn.MSELoss(reduction='mean')\n",
    "\n",
    "# Use the optim package to define an Optimizer that will update the weights of\n",
    "# the model for us. Here we use Adam; the optim package contains many other\n",
    "# optimization algorithms. The first argument to the Adam constructor tells the\n",
    "# optimizer which Tensors it should update.\n",
    "learning_rate = 1e-4\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZIAAAEKCAYAAAA4t9PUAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3X2cVnP+x/HXp5luKJFKpUIlEZImqUWEpVjKblGRCC021tpFrZW0y7rZ3ZbVrlaSflty09pCttqV3CXdIylT7iK3lUxUqs/vj++Z7Wqaaa65ua5zzcz7+Xicx3XOub7nzOdcM12fzvd7vt+vuTsiIiKlVS3uAEREpGJTIhERkTJRIhERkTJRIhERkTJRIhERkTJRIhERkTJRIhERkTJRIhERkTJRIhERkTLJjjuAdGjQoIEfcsghpTp206ZN1K5du3wDKgeZGhdkbmyKq2QUV8llamyljWvhwoVfunvDYgu6e6VfcnJyvLRmz55d6mNTKVPjcs/c2BRXySiuksvU2EobF7DAk/iOVdWWiIiUiRKJiIiUiRKJiIiUiRKJiIiUiRKJiIiUiRKJiIiUiRJJcXbsiDsCEZGMpkRSlK++gksvpcOQIaDpiEVEiqREUpR99oFnn6Xu8uUwb17c0YiIZCwlkqLUqAGXXBLWH3ww1lBERDKZEsmeXH55eJ08GTZujDcWEZEMpUSyJ4cdxoZjjoFvv4VJk+KORkQkIymRFOOTs88OK6reEhEplBJJMb48+WSoVw8WLQqLiIjsQomkGDtq1ICLLw4buisREdmNEkkyrrgivE6cCJs2xRuLiEiGUSJJxpFHQpcu8M038PjjcUcjIpJRlEiSlX9XouotEZFdKJEk6/zzoW5dmDsXli2LOxoRkYyhRJKs2rWhf/+wrrsSEZH/SWkiMbPuZrbCzHLNbGgh719vZm+b2Rtm9l8zOzjhvYFm9m60DEzYn2Nmb0bnvM/MLJXXsIv86q0JE2Dz5rT9WBGRTJayRGJmWcBooAfQFuhnZm0LFFsMdHT3dsCTwN3RsfsDtwLHA52AW82sXnTM34DBQOto6Z6qa9hNhw5hWb8epkxJ248VEclkqbwj6QTkuvtqd98KTAZ6JhZw99nu/m20+RrQLFo/E5jl7uvcfT0wC+huZk2Auu4+190dmAD0SuE17G7w4PCq6i0RESC1iaQp8FHC9ppoX1EuA54r5tim0Xqy5yx//frB3nvDnDmwcmVaf7SISCbKTuG5C2u7KHSGKDO7COgInFzMsSU552BCFRiNGjXihRdeKCbcwuXl5e12bJuTT6bJc8/x4fDhrL7yylKdt6wKiytTZGpsiqtkFFfJZWpsKY/L3VOyAF2AGQnbw4BhhZQ7HVgOHJCwrx8wJmF7TLSvCfBOUeWKWnJycry0Zs+evfvOuXPdwb1hQ/ctW0p97rIoNK4MkamxKa6SUVwll6mxlTYuYIEn8X2fyqqt+UBrM2thZjWAvsC0xAJmdmyUJM51988T3poBnGFm9aJG9jMISWkt8I2ZdY6e1roYmJrCayjc8cfDUUfBF1/AtGnFlxcRqcRSlkjcfRswhJAUlgOPu/syMxtpZudGxe4B6gBPmNkSM5sWHbsO+C0hGc0HRkb7AK4CxgK5wCp2tqukj5l6uouIRFLZRoK7TwemF9g3PGH99D0cOw4YV8j+BcBR5Rhm6Vx0Edx4I8yaBR98AAcfXPwxIiKVkHq2l9b++8N554E7PPJI3NGIiMRGiaQsLrssvD78MOzYEW8sIiIxUSIpi1NPDVVa778Pzz8fdzQiIrFQIimLatXg0kvD+kMPxRuLiEhMlEjK6tJLw1NcTz0F69YVX15EpJJRIimrgw6CH/4QtmyBSZPijkZEJO2USMpDfqO7qrdEpApSIikPPXuGx4GXLIFFi+KORkQkrZRIykPNmqGDIsC43fpQiohUakok5SW/emviRPju
u3hjERFJIyWS8tKuHXTsCBs2hCe4RESqCCWS8qRGdxGpgpRIylPfvlCrVujlvnp13NGIiKSFEkl52m8/6N07rI8fH2soIiLpokRS3hIHcty+Pd5YRETSQImkvJ18MrRqBWvWhLlKREQqOSWS8mYGgwaFdTW6i0gVoESSCgMHhpGBp06FL7+MOxoRkZRSIkmFpk2he3f4/ns1uotIpZfSRGJm3c1shZnlmtnQQt7vamaLzGybmfVO2N/NzJYkLJvNrFf03ngzey/hvfapvIZSu/LK8PrAA5o9UUQqtZQlEjPLAkYDPYC2QD8za1ug2IfAJcAu46+7+2x3b+/u7YFTgW+BmQlFbsh/392XpOoayuSss8LsiatWwcyZxZcXEamgUnlH0gnIdffV7r4VmAz0TCzg7u+7+xvAnv7L3ht4zt2/TV2oKZCVtfOu5K9/jTcWEZEUMndPzYlDVVV3d7882h4AHO/uQwopOx54xt2fLOS954E/ufszCWW7AFuA/wJD3X1LIccNBgYDNGrUKGfy5Mmluo68vDzq1KlTqmOrr19PlwsuwLZtY96kSWxu3LhU5ynvuFItU2NTXCWjuEouU2MrbVzdunVb6O4diy3o7ilZgD7A2ITtAcBfiig7HuhdyP4mwBdA9QL7DKgJPAIMLy6WnJwcL63Zs2eX+lh3d7/wQndwHzq0bOcpoMxxpVCmxqa4SkZxlVymxlbauIAFnsT3fSqrttYAzRO2mwGflPAc5wNPufv3+TvcfW10jVuAhwlVaJnr6qvD69ixYTpeEZFKJpWJZD7Q2sxamFkNoC8wrYTn6Ac8mrjDzJpErwb0At4qh1hTp0sXOOaY0J/k8cfjjkZEpNylLJG4+zZgCDADWA487u7LzGykmZ0LYGbHmdkaQjXYGDNbln+8mR1CuKOZU+DUE83sTeBNoAHwu1RdQ7kwg2uuCeujRkGK2qREROKSncqTu/t0YHqBfcMT1ucTqrwKO/Z9oGkh+08t3yjT4MILYdgwWLwYXnwxjMclIlJJqGd7OtSqBVddFdZHjYo3FhGRcqZEki5XXw01asC0aZCbG3c0IiLlRokkXRo1gv79QxvJfffFHY2ISLlRIkmnX/wivI4bBxs2xBuLiEg5USJJp3bt4LTTYNOmMJijiEgloESSbjfcEF5HjYJvK9bwYSIihVEiSbczzoCOHeHzz0NvdxGRCk6JJN3M4Oabw/o992jYFBGp8JRI4nDuuXDUUbBmDUyYEHc0IiJlokQSh2rV4Ne/Dut33gnbtsUbj4hIGSiRxOX886F1a1i9GiZNKr68iEiGUiKJS1bWzruSESNg69ZYwxERKS0lkjhddBEcfji8956e4BKRCkuJJE7Z2XD77WF95MjQUVFEpIJRIonbeefBccfBZ5/BvffGHY2ISIkpkcTNDH7/+7B+992wbl288YiIlJASSSY47TQ4/XT4+uudVV0iIhXEHhOJmWWZ2S/SFUyVdtdd4e7kvvvgnXfijkZEJGl7TCTuvh3oWdqTm1l3M1thZrlmNrSQ97ua2SIz22ZmvQu8t93MlkTLtIT9Lcxsnpm9a2aPmVmN0saXUTp0gMsvD50Tr7tOc7uLSIWRTNXWK2Z2v5mdZGYd8pfiDjKzLGA00ANoC/Qzs7YFin0IXAIU1iPvO3dvHy3nJuy/Cxjl7q2B9cBlSVxDxXD77bDvvjBjBjz9dNzRiIgkJZlE8gPgSGAk8Mdo+UMSx3UCct19tbtvBSZT4O7G3d939zeAHckEa2YGnAo8Ge16BOiVzLEVQsOG4TFgCJNgbd4cbzwiIkkoNpG4e7dCllOTOHdT4KOE7TXRvmTVMrMFZvaameUni/rABnfPH5yqpOfMfFddBUceGYZO+UMy+VpEJF7mxdTFm9m+wK1A12jXHGCku39dzHF9gDPd/fJoewDQyd2vKaTseOAZd38yYd+B7v6JmbUEngdOAzYCc9390KhMc2C6ux9dyDkHA4MBGjVqlDN58uQ9
XmdR8vLyqFOnTqmOLa39Fi+m/fXXs6N6deaPHct3Bx2UEXElK1NjU1wlo7hKLlNjK21c3bp1W+juHYst6O57XIApwG1Ay2i5FfhnEsd1AWYkbA8DhhVRdjzQew/nGg/0Bgz4Esgu7GcUteTk5HhpzZ49u9THlsmgQe7gfuKJ7tu37/Z2bHElIVNjU1wlo7hKLlNjK21cwAIv5vvV3ZNqI2nl7rd6aOtY7e75SaU484HW0VNWNYC+wLRijgHAzOqZWc1ovQFwAvB2dGGzo6QCMBCYmsw5K5w//AEaN4aXX4YxY+KORkSkSMkkku/M7MT8DTM7AfiuuIM8tGMMAWYAy4HH3X2ZmY00s3Ojcx1nZmuAPsAYM1sWHX4EsMDMlhISx53u/nb03k3A9WaWS2gzeSiZC61w6tWD++8P6zfdFCbBEhHJQNlJlLkSmBC1lUB45HZgMid39+nA9AL7hieszweaFXLcq8Bu7R7Re6sJT4RVfj/5SRiL66mnQh+T554LnRZFRDJIcT3bqwFt3P0YoB3Qzt2P9fDIrqTD6NGw//6hb8no0XFHIyKym+J6tu8gVE/h7hvdfWNaopKdmjSBBx8M6zfcAG+/vefyIiJplkwbySwz+5WZNTez/fOXlEcmO/34xzBoUOig2L8/bNkSd0QiIv+TTCIZBPwMeBFYGC0LUhmUFOLPf4ZWrWDp0p1T9IqIZIBk2kgucvcWBZZkHv+V8rTPPvCPf4S53v/0J+q/8krcEYmIAMm1kWicjkzRuTPceScAh995Z5jrXUQkZslUbc00s59EAyZK3H75Szj3XKrn5UGfPmovEZHYJZNIrgeeALaa2UYz+8bM9PRWXMxg/Hi+a9wYFi4Mc5eIiMQomdF/93H3au5e3d3rRtt10xGcFKFePd4eMQJq1IAHHoC//S3uiESkCis2kVhwkZndEm03N7Oq0bM8g33Tps3O/iXXXAPPPx9vQCJSZSVTtfVXwii7/aPtPMLMhxK3iy+GG2+E7dtDe8mqVXFHJCJVUDKJ5Hh3/xmwGcDd1wOVY570yuCOO+BHP4J16+Ccc2Cjmq9EJL2SSSTfR/OvO4CZNSTJqXElDbKyYOLEMKvi8uXQr1+4QxERSZNkEsl9wFPAAWZ2O/AycEdKo5KSqVsXpk2D+vVh+nT4+c+hmJkvRUTKSzJPbU0EbgR+D6wFern7E6kOTEqoZUv45z/Dk1yjR8Pdd8cdkYhUEcnMR4K7vwO8k+JYpKy6dg3DqFxwAQwdCgceCAMGxB2ViFRyyVRtSUXSp08Y4BHCiMHPPRdvPCJS6SmRVEbXXhseC962LQxBrz4mIpJCSSUSMzvYzE6P1vcys31SG5aU2Z13wpVXhjlMzjkHXn457ohEpJJKpmf7FcCTwJhoVzPgX8mc3My6m9kKM8s1s6GFvN/VzBaZ2TYz652wv72ZzTWzZWb2hpldkPDeeDN7z8yWREv7ZGKpcsxCo/sll8C338JZZ8Hrr8cdlYhUQsnckfwMOAHYCODu7wIHFHdQ1PdkNNADaAv0M7O2BYp9CFwCTCqw/1vgYnc/EugO/NnM9kt4/wZ3bx8tS5K4hqqpWjUYOzb0LfnmGzjzTFi8OO6oRKSSSSaRbHH3rfkbZpZN1DmxGJ2AXHdfHR0/GeiZWMDd33f3NyjQwdHdV0YJC3f/BPgcaJjEz5SCsrLgkUfgvPNgwwY4/fQwarCISDkxL6bjmpndDWwALgauAa4G3nb3m4s5rjfQ3d0vj7YHEIZbGVJI2fHAM+7+ZCHvdQIeAY509x1R2S7AFuC/wFB3321SDjMbDAwGaNSoUc7kyZP3eJ1FycvLo06dOqU6NpVKGpd9/z1HjhhBg1dfZVvt2rzx+9+z8eijMyK2dFFcJaO4Si5TYyttXN26dVvo7h2LLejue1wIdy1XEOYkeTJatySO6wOMTdgeAPyliLLjgd6F7G8CrAA6F9hnQE1CghleXCw5OTleWrNnzy71salUqri2bHHv
08cd3Pfe2/0//yn3uNwr2WeWBoqrZDI1LvfMja20cQELvJjvV3dPqmf7Dnd/0N37uHvvaD2Zqq01QPOE7WbAJ0kcB4CZ1QWeBX7j7q8lxLM2usYtwMOEKjRJRo0aMGkSDBwYGuDPPhuefTbuqESkgkvmqa03oyenEpeXzGyUmdXfw6HzgdZm1sLMagB9gWnJBBWVfwqY4AWGYzGzJtGrAb2At5I5p0Sys2HcOLjqqjBNb69eUMpqPxERSG6IlOeA7ex8sqpv9LqRUCV1TmEHufs2MxsCzACygHHuvszMRhJul6aZ2XGEhFEPOMfMbvPwpNb5QFegvpldEp3yEg9PaE2MRiA2YAlwZUkuWAhPc40eDbVrwx/+EJ7q+uCD0InRLO7oRKSCSSaRnODuJyRsv2lmr7j7CWZ20Z4OdPfpwPQC+4YnrM8nVHkVPO4fwD+KOOepScQsxTELAzs2bgy/+lUYm+u99+D++8Ndi4hIkpJ5/LeOmR2fvxE9RZXf/L8tJVFJepjBL38JTzwBNWvCmDGhF/w338QdmYhUIMkkksuBsVFv8veBscAVZlabMLS8VHS9e4fxuBo0gH//O4wi/MEHcUclIhVEMk9tzXf3o4H2QHt3b+fur7v7Jnd/PPUhSlr84Acwdy60bg1LlkBODsyaFXdUIlIBJDto49nAT4FrzWy4mQ0v7hipgA49FF57DXr0gK++CkOq3HEH7NDMyiJStGQe/30AuIDQq90IHQ0PTnFcEpf994dnnoFbbw3T9d58cxiK/uuv445MRDJUMnckP3D3i4H17n4bYXiS5sUcIxVZtWowYkRIKPvtB1OnQseO8Ja67IjI7pJJJJuj12/N7EDge6BF6kKSjHH22bBgAbRrB7m5cPzxoTNjUgMbiEhVkUwieToawv0eYBHwPvBoKoOSDNKqVWiEHzAgDKty2WXQty+sXx93ZCKSIfaYSMysGvBfd9/g7lMIbSOHJ3YqlCpg773DUPSPPAJ16sDjj0P79vDSS3FHJiIZYI+JxN13AH9M2N7i7mp1rYrM4OKLw8RYxx0HH34Ip5wCw4fD99/HHZ2IxCiZqq2ZZvaTaJBEqeoOPRReeQWGDQttJb/9beiD8s47cUcmIjFJJpFcT5iLZKuZbTSzb8xsY4rjkkxWvXroX/L883DQQaFB/thj4S9/UZ8TkSoomZ7t+7h7NXev7u51o+266QhOMtwpp8Abb4T5TTZvhmuvhTPPpOYXX8QdmYikUTIdEs3MLjKzW6Lt5tHAjSKw774wfjxMmQL168N//kPHQYPgUT3YJ1JVJFO19VdCJ8T+0XYeMDplEUnF9OMfhw6LZ59N9bw86N8/PCa8bl3ckYlIiiWTSI53958RdUx09/VAjZRGJRVT48bw9NOsuP76MGnWY4/B0UfDzJlxRyYiKZRMIvnezLIAB4hmJ1SLqhTOjLXnnANLl0KXLvDJJ2HwxyFDQodGEal0kkkk9xGmwz3AzG4HXgbuSGlUUvG1agUvvgi33x5mXBw9OjzZNX9+3JGJSDlL5qmticCNhEms1gK93P2JZE5uZt3NbIWZ5ZrZ0ELe72pmi8xsm5n1LvDeQDN7N1oGJuzPMbM3o3Pep/4tGSw7G379a3j9dWjbFlauDHcpI0aoE6NIJZLMU1v3Avu7+2h3v9/dlydz4qg6bDTQA2gL9DOztgWKfQhcAkwqcOz+wK3A8UAn4FYzqxe9/TdgMNA6WronE4/E6NhjYeFC+MUvYPt2uO02OOEEWLEi7shEpBwkU7W1CPhNdAdwj5l1TPLcnYBcd1/t7luByUDPxALu/r67v8HubS5nArPcfV3UuD8L6G5mTYC67j7X3R2YAPRKMh6JU61a8Kc/hU6MzZuHKq5jjw1VXhpNWKRCS6Zq6xF3P4uQGFYCd5nZu0mcuynwUcL2mmhfMoo6tmm0XppzSibo1i10YhwwAL77LjTCd+8OH38cd2QiUkrZJSh7KHA4cAjwdhLlC2u7SPa/
nkUdm/Q5zWwwoQqMRo0a8cILLyT5o3eVl5dX6mNTKVPjgiRjGzSIhq1acdif/kT1mTP5/ogjWHnddXxx6qnxxhUDxVUymRoXZG5sKY/L3fe4AHcB7wL/BgYB+xV3THRcF2BGwvYwYFgRZccDvRO2+wFjErbHRPuaAO8UVa6oJScnx0tr9uzZpT42lTI1LvcSxvbJJ+49eriHCi73fv3c162LP640Ulwlk6lxuWdubKWNC1jgSXzfJ9NG8h7Qxd27u/s4d9+QZI6aD7Q2sxZmVgPoC0xL8tgZwBlmVi9qZD+DkJTWAt+YWefoaa2LgalJnlMyUZMm8Oyz8MADYd6TRx8NnRinTFHbiUgFkUwbyQPAdjPrFD2u29XMuiZx3DZgCCEpLAced/dlZjbSzM4FMLPjzGwN0AcYY2bLomPXAb8lJKP5wMhoH8BVwFggF1gFPFeyS5aMYwY//SksWQKdO4f2kt69Q0dGDU8vkvGKbSMxs8uBnwPNgCVAZ2AuUGxltrtPB6YX2Dc8YX1+dN7Cjh0HjCtk/wLgqOJ+tlRArVvDyy/D3/8ON98Ms2aFu5Nf/AJuuQX22SfuCEWkEMlUbf0cOA74wN27AccCGidcUiMrC666KnReHDw49Du55x447LCQYLZtiztCESkgmUSy2d03A5hZTXd/B2iT2rCkymvQAMaMgXnz4Pjj4dNPQ/VXu3bw9NNqPxHJIMkkkjVmth/wL2CWmU0FPkltWCKR446DuXPDSMItW8Ly5XDuuaE/yquvxh2diJBcY/t57r7B3UcAtwAPod7kkk5mcP75IYnce2+YQGvOnDDMSo8eYSwvEYlNMnck/+Puc9x9mochT0TSq0aNMJ3vqlXwm9+Exvd//ztUff3oR2E8LxFJuxIlEpGMsO++8NvfwnvvwbBhYRKtZ5+Fjh2hV6/wGLGIpI0SiVRc9evDHXeEhHLDDbDXXjB1ahgMsnfvMPWviKScEolUfA0bwt13w+rVoc9JrVqhZ3y7dvCTn6gNRSTFlEik8mjcOAxVv2oVXHMNVK8O//xnaEM55RT2f+01PTYskgJKJFL5HHgg3HdfqPK66SaoWxfmzKHdsGHhLmXCBNiq50VEyosSiVReBx4Id94JH30E99zDlgYNQrvJwIFhTvlRo+Cbb+KOUqTCUyKRyq9uXfjVr3ht0iR4+GE44ghYswauvx4OOiiM6/XZZ3FHKVJhKZFIleHVq8Mll4S7kmnT4MQTYcOG8OTXwQeHIVjeTWbyTxFJpEQiVU+1anDOOfDSS/DKK9CzJ2zZEgaFbNMmPDo8d27cUYpUGEokUrX94Afwr3+F4Vcuuyw86TVlStjfuXMY40sjDovskRKJCMDhh8PYsTt7y9erF0Ye7ts3DBZ5zz2hGkxEdqNEIpLowANDm8lHH8Ff/xrmQfnoI7jxRmjWLPRPyc2NO0qRjKJEIlKY2rXDBFvLl8Mzz8Bpp8GmTXD//SG59OoFs2erg6MIKU4kZtbdzFaYWa6ZDS3k/Zpm9lj0/jwzOyTaf6GZLUlYdphZ++i9F6Jz5r93QCqvQaq4atXg7LPhP/+BpUvh0ktDO8rUqXDqqaGD45gxIcmIVFEpSyRmlgWMBnoAbYF+Zta2QLHLgPXufigwCrgLwN0nunt7d28PDADed/fEIV0vzH/f3T9P1TWI7KJdOxg3Dj78EG69NQzJ8tZbcOWVodrrl78M432JVDGpvCPpBOS6++po/pLJQM8CZXoCj0TrTwKnmZkVKNMPeDSFcYqUTKNGMGIEfPABTJwYnu7asCGM83XooWEGx1mzVO0lVUYqE0lT4KOE7TXRvkLLuPs24GugfoEyF7B7Ink4qta6pZDEI5IeNWpA//6hz8n8+XDxxaHa6+mn4YwzoG1bGD1aw7BIpWeeov81mVkf4Ex3vzzaHgB0cvdrEsosi8qsibZXRWW+iraPB8a6+9EJxzR194/NbB9gCvAPd59QyM8fDAwGaNSoUc7kyZNLdR15
eXnUqVOnVMemUqbGBZkbWzriqr5+PU2efZamU6dS88svAdhWuzafnnkmH593Ht81axZLXKWhuEouU2MrbVzdunVb6O4diy3o7ilZgC7AjITtYcCwAmVmAF2i9WzgS6LkFu0bBfx6Dz/jEuD+4mLJycnx0po9e3apj02lTI3LPXNjS2tcW7e6P/64+0knuYdKrrD06OE+fbr79u3xxFUCiqvkMjW20sYFLPAkvu9TWbU1H2htZi3MrAbQF5hWoMw0YGC03ht4PgoeM6sG9CG0rRDtyzazBtF6deBHgKbBk8xTvTr06QMvvgiLF4de87VqwXPPwVlnhUeI77pLg0VKpZCyROKhzWMI4a5jOfC4uy8zs5Fmdm5U7CGgvpnlAtcDiY8IdwXWuHviYzA1gRlm9gawBPgYeDBV1yBSLtq3D73m16wJyePgg8PkW0OHQrNmtB0xIjTO79gRd6QipZKdypO7+3RgeoF9wxPWNxPuOgo79gWgc4F9m4Cccg9UJB3q1w895H/5S5gxIwwS+cwzHDBnDsyZE4ZiueKKMEJx48ZxRyuSNPVsF0m3rKxQvfWvf8EHH/DeoEFhXpTVq8M4X82bhxGIdZciFYQSiUicmjblgwEDQhKZPj0MveIeRiA+44zQL+V3vwt9VkQylBKJSCbIyoIePeCpp0LP+d/9LrSlvPce3HILHHJIGO/r//5Pw7FIxlEiEck0Bx4Ypv9dtSq0pfTrF574ev750OmxcWMYNCg8Eabe85IBlEhEMlVWVqjemjQJ1q4Ng0N26QJ5eWHu+ZNPDlVft90W7lxEYqJEIlIR7LcfDB4Mr74K77wDv/51GChy9eow7lfLlnDKKeEx4/Xr445WqhglEpGKpk0buP12eP99mDkzjPdVq1Z4hPiKK8Kgkj17wuTJak+RtFAiEamosrLghz8MIxB/+ik89BCcfjps3w7TpoW2lUaN4KKLQo96zT0vKaJEIlIZ7LtvaICfNQs+/hjuvTcMb79pU0g0Z50FTZvCddfBwoVqpJdypUQiUtk0bgzXXhuGt1+1Cn772zC21+efhwTTsWMY4v6222DlyrijlUpAiUSkMmvZEn7zm9BA//rrcM010LBh2B4xIrS3dOgA99wT2lxESkGJRKTf2UZVAAAPqklEQVQqMIPjjoP77gtVX//+NwwcCHXrhtGJb7wRWrSAnJzQkL98edwRSwWiRCJS1VSvDmeeCePHh2Hsn3oKLrgA6tSBRYvCHUzbthw3cGDoGKk2FSmGEolIVVarVhjfa/Jk+OKLME3wpZfC/vtT+8MP4Y47QptKixZw/fXw8svhqTCRBEokIhLUqgU/+hGMGweffcaSP/4Rrr4amjQJg0aOGgUnnRSe/rryytCH5fvv445aMoASiYjsLjubDR06wOjRYUKuV1+FX/0q3Jl89lkYruXMM+GAA8L4X1OnwnffxR21xESJRET2rFq1MMbXPfeEx4kXLw4jEh95JGzYEEYk7tULGjQI0ws/+ihs3Bh31JJGSiQikjyzMHXwyJHw1lvhMeLf/z60o3z7LTz5ZBiypWFDOPtsePDB8JSYVGpKJCJSem3ahLnn588P7Sh//jN07RraTqZPDwNNNmsGxx4bngB75RU11ldCKU0kZtbdzFaYWa6ZDS3k/Zpm9lj0/jwzOyTaf4iZfWdmS6LlgYRjcszszeiY+8zMUnkNIpKkgw6Cn/88DB6ZP+z9OefA3nvDkiXhCbATTwztKv37wz/+AV9+GXfUUg5SlkjMLAsYDfQA2gL9zKxtgWKXAevd/VBgFHBXwnur3L19tFyZsP9vwGCgdbR0T9U1iEgpNWoU7kamTYOvvgoTdF17LbRqBevWhXaUAQNCUunSJQzjsnCh5qivoFJ5R9IJyHX31e6+FZgM9CxQpifwSLT+JHDanu4wzKwJUNfd57q7AxOAXuUfuoiUm1q1wgRd994LublhfK9Ro8LIxdWrw2uvwfDhoZ2ladMw+OSkSWpbqUBS
mUiaAh8lbK+J9hVaxt23AV8D9aP3WpjZYjObY2YnJZRfU8w5RSSTtW4dRiGeOTPcrUydurMt5dNPw+yPF14Ytjt1CvPXv/IKbN0ad+RSBPMUDX1gZn2AM9398mh7ANDJ3a9JKLMsKrMm2l5FuJPJA+q4+1dmlgP8CzgSaAP83t1Pj8qfBNzo7ucU8vMHE6rAaNSoUc7kyZNLdR15eXnUqVOnVMemUqbGBZkbm+IqmbTH5U7t995j/3nz2G/pUvZbupSszZv/9/b2WrX4+uijWXvMMWw68US+Peig8BRZBqlsv8tu3botdPeOxRZ095QsQBdgRsL2MGBYgTIzgC7RejbwJVFyK1DuBaAj0AR4J2F/P2BMcbHk5OR4ac2ePbvUx6ZSpsblnrmxKa6SiT2uTZvcp0xxv/pq9yOOcA8jfu1cmjd3v/xy90mT3D/+ON5YI7F/ZkUobVzAAk/i+z6VVVvzgdZm1sLMagB9gWkFykwDBkbrvYHn3d3NrGHUWI+ZtSQ0qq9297XAN2bWOWpLuRiYmsJrEJG47L03/PjHoXf922+HJ8EmTOCz008P/VQ++ijMUd+/f2hbadMGfvrT0JD/ySdxR1+lZKfqxO6+zcyGEO46soBx7r7MzEYSstw04CHg/8wsF1hHSDYAXYGRZrYN2A5c6e7roveuAsYDewHPRYuIVHaNG8OAASxv3pxGXbuGHvazZoXHjV96KTTir1wJf/97KN+6NZxySlhOPjkkG0mJlCUSAHefDkwvsG94wvpmoE8hx00BphRxzgXAUeUbqYhUKNWqhblTcnJCh8jvvw9D4L/wQlhefhnefTcsDz4Yjjn00F0TS7Nm8cVfyaQ0kYiIpEX16nD88WG56SbYtm33xJKbG5axY8MxrVrtTCqdO4dEk2GN9xWFEomIVD7Z2eHR4U6dwuyP27aFqrD8xPLSS2EAylWr4KGHwjH164fynTuHhNSpE9SrF+dVVBhKJCJS+WVnh6mGjzsObrghJJYlS3YmlXnzwvD4zz0Xlnxt2oSkkp9cjj463P3ILpRIRKTqyc4OPek7dgzzrLiHQSfnzQs97efNC1VjK1aEZcKEcNxee4V2mfw7lpwcaNmyyleJKZGIiJjBIYeE5YILwr6tW2Hp0l2TS25uaG95+eWdx+67bxhav0MHGu21Vxg/rE0byMqK40pioUQiIlKYGjV2VocNGRL2ffklvP56SCoLFoSBJj/7LDyCPGcOR0AY5XjvveGYY6BDhzCEfocOYSKwGjXivKKUUSIREUlWgwZw1llhybd2bagGW7yYL2bOpOGHH4Zqsrlzw5KvevXQxtKhw87l6KND0qnglEhERMqiSZMwG+TZZ7PsxBM55ZRTwmCUixeHBBMlGVau3Lmdr1o1OOKIkFCOOiosRx4JLVpUqKoxJRIRkfJWvz6cfnpY8m3cGNpc8pPJokWwfDksWxaWRHvtBW3b7kws+a/Nm2dkw74SiYhIOtStCyedFJZ8330Hb721c1m2LLx+/HFof1m4cNdz1K4d7mDatt35evjh4Q4mxseSlUhEROKy1147G/QTrV+/804lP8ksXx4a9hcsCEui7OzQU79Nm53LYYeF14YNU34ZSiQiIpmmXr0wv/2JJ+66/6uvQkJ5++2d1WIrVsCHH+7s81LQnXeGfi8ppEQiIlJR1K9feIL59tswQOXKlTsTSv7SsmXKw1IiERGp6PL7rRxzzK773WHHjjAMTAqlcmIrERGJk1laHiNWIhERkTJRIhERkTJRIhERkTJJaSIxs+5mtsLMcs1saCHv1zSzx6L355nZIdH+H5rZQjN7M3o9NeGYF6JzLomWA1J5DSIismcpe2rLzLKA0cAPgTXAfDOb5u5vJxS7DFjv7oeaWV/gLuAC4EvgHHf/xMyOAmYATROOuzCau11ERGKWyjuSTkCuu692963AZKBngTI9gUei9SeB08zM3H2xu38S7V8G1DKzmimMVURESimViaQp8FHC9hp2
vavYpYy7bwO+BuoXKPMTYLG7b0nY93BUrXWLWQaOYCYiUoWkskNiYV/wXpIyZnYkobrrjIT3L3T3j81sH2AKMACYsNsPNxsMDI4288yskLEDktKAUNWWaTI1Lsjc2BRXySiuksvU2Eob18HJFEplIlkDNE/YbgZ8UkSZNWaWDewLrAMws2bAU8DF7r4q/wB3/zh6/cbMJhGq0HZLJO7+d+DvZb0IM1vg7h3Lep7ylqlxQebGprhKRnGVXKbGluq4Ulm1NR9obWYtzKwG0BeYVqDMNGBgtN4beN7d3cz2A54Fhrn7K/mFzSzbzBpE69WBHwFvpfAaRESkGClLJFGbxxDCE1fLgcfdfZmZjTSzc6NiDwH1zSwXuB7If0R4CHAocEuBx3xrAjPM7A1gCfAx8GCqrkFERIqX0kEb3X06ML3AvuEJ65uBPoUc9zvgd0WcNqc8Y0xCmavHUiRT44LMjU1xlYziKrlMjS2lcZl7wfZvERGR5GmIFBERKRMlkj0oboiXNMbR3Mxmm9lyM1tmZj+P9o8ws48T2pHOiiG296OhbJaY2YJo3/5mNsvM3o1e66U5pjYJn8kSM9toZtfF9XmZ2Tgz+9zM3krYV+hnZMF90d/cG2bWIc1x3WNm70Q/+6nowRfM7BAz+y7hs3sgzXEV+bszs2HR57XCzM5Mc1yPJcT0vpktifan8/Mq6vshfX9j7q6lkAXIAlYBLYEawFKgbUyxNAE6ROv7ACuBtsAI4Fcxf07vAw0K7LsbGBqtDwXuivn3+CnhefhYPi+gK9ABeKu4zwg4C3iO0MeqMzAvzXGdAWRH63clxHVIYrkYPq9Cf3fRv4OlhAdxWkT/ZrPSFVeB9/8IDI/h8yrq+yFtf2O6IylaMkO8pIW7r3X3RdH6N4Sn4AqOEpBJEoe+eQToFWMspwGr3P2DuAJw9xeJ+kclKOoz6glM8OA1YD8za5KuuNx9pocnLgFeI/T/SqsiPq+i9AQmu/sWd38PyCX8201rXGZmwPnAo6n42Xuyh++HtP2NKZEULZkhXtLOwgjJxwLzol1DotvTcemuQoo4MNPCKM35Iwk0cve1EP7IgThHaO7Lrv+44/688hX1GWXS390gwv9c87Uws8VmNsfMToohnsJ+d5nyeZ0EfObu7ybsS/vnVeD7IW1/Y0okRUtmiJe0MrM6hGFhrnP3jcDfgFZAe2At4dY63U5w9w5AD+BnZtY1hhgKZaEj7LnAE9GuTPi8ipMRf3dmdjOwDZgY7VoLHOTuxxL6fE0ys7ppDKmo311GfF5AP3b9D0vaP69Cvh+KLFrIvjJ9ZkokRUtmiJe0sdCTfwow0d3/CeDun7n7dnffQeiYmZJb+j3xaJRmd/+cMKRNJ+Cz/Fvl6PXzdMcV6QEscvfPohhj/7wSFPUZxf53Z2YDCaNGXOhRpXpUdfRVtL6Q0BZxWLpi2sPvLhM+r2zgx8Bj+fvS/XkV9v1AGv/GlEiKlswQL2kR1b8+BCx39z8l7E+s1zyPNA8XY2a1LQyeiZnVJjTUvsWuQ98MBKamM64Eu/wvMe7Pq4CiPqNpwMXRkzWdga/zqyfSwcy6AzcB57r7twn7G1qYYwgzawm0BlanMa6ifnfTgL4WJslrEcX1erriipwOvOPua/J3pPPzKur7gXT+jaXjqYKKuhCeblhJ+N/EzTHGcSLh1jN/aJglUWz/B7wZ7Z8GNElzXC0JT8wsJcwbc3O0vz7wX+Dd6HX/GD6zvYGvgH0T9sXyeRGS2Vrge8L/Bi8r6jMiVDuMjv7m3gQ6pjmuXEL9ef7f2QNR2Z9Ev+OlwCLCxHPpjKvI3x1wc/R5rQB6pDOuaP944MoCZdP5eRX1/ZC2vzH1bBcRkTJR1ZaIiJSJEomIiJSJEomIiJSJEomIiJSJEomIiJSJEolIhjOzU8zsmbjjECmKEomIiJSJEolIOTGzi8zs9Wj+iTFmlmVmeWb2RzNbZGb/NbOGUdn2Zvaa7Zz3I3+uiEPN
7D9mtjQ6plV0+jpm9qSFuUImRr2ZRTKCEolIOTCzI4ALCINYtge2AxcCtQnjfXUA5gC3RodMAG5y93aE3sX5+ycCo939GOAHhJ7UEEZ0vY4wz0RL4ISUX5RIkrLjDkCkkjgNyAHmRzcLexEGydvBzsH8/gH808z2BfZz9znR/keAJ6Jxy5q6+1MA7r4ZIDrf6x6N5WRhFr5DgJdTf1kixVMiESkfBjzi7sN22Wl2S4FyexqTaE/VVVsS1rejf7uSQVS1JVI+/gv0NrMD4H/zZR9M+DfWOyrTH3jZ3b8G1idMdjQAmONhDok1ZtYrOkdNM9s7rVchUgr6X41IOXD3t83sN4TZIqsRRoj9GbAJONLMFgJfE9pRIAzr/UCUKFYDl0b7BwBjzGxkdI4+abwMkVLR6L8iKWRmee5eJ+44RFJJVVsiIlImuiMREZEy0R2JiIiUiRKJiIiUiRKJiIiUiRKJiIiUiRKJiIiUiRKJiIiUyf8DeqiFahCM99wAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Train the model one sample at a time (stochastic updates) and record the\n",
    "# average per-sample loss of every epoch.\n",
    "N_EPOCHS = 200\n",
    "N_FEATURES = 4  # columns [0, 4) are the inputs; the remaining columns are the target\n",
    "\n",
    "n_samples = train_data.shape[0]\n",
    "idx = np.arange(n_samples)\n",
    "# Local, seeded generator: the visiting order is reproducible and the global\n",
    "# NumPy random state is left alone.\n",
    "shuffle_rng = np.random.RandomState(0)\n",
    "\n",
    "avg_loss_list = list()\n",
    "epoch_list = list()\n",
    "for epoch in range(N_EPOCHS):\n",
    "    total_loss = 0.0\n",
    "    shuffle_rng.shuffle(idx)  # new sample order each epoch\n",
    "    for sample_idx in idx:\n",
    "        # Forward pass: compute predicted y by passing x to the model.\n",
    "        y_pred = model(train_data[sample_idx, :N_FEATURES])\n",
    "        y = train_data[sample_idx, N_FEATURES:]\n",
    "\n",
    "        loss = loss_fn(y_pred, y)\n",
    "        # Accumulate a plain Python float. Adding the loss tensor itself would\n",
    "        # keep every sample's autograd graph alive for the whole epoch and put\n",
    "        # grad-tracking tensors into the list that is plotted below.\n",
    "        total_loss += loss.item()\n",
    "\n",
    "        # Before the backward pass, zero the gradients of all tensors the\n",
    "        # optimizer updates (the learnable weights of the model).\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Backward pass: compute gradient of the loss with respect to model parameters\n",
    "        loss.backward()\n",
    "\n",
    "        # Calling the step function on an Optimizer makes an update to its parameters\n",
    "        optimizer.step()\n",
    "\n",
    "    avg_loss = total_loss / n_samples\n",
    "    avg_loss_list.append(avg_loss)\n",
    "    epoch_list.append(epoch)\n",
    "\n",
    "# Plot loss\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(epoch_list, avg_loss_list, 'r-', lw=2)\n",
    "ax.set_title(\"Average training loss per epoch\")\n",
    "ax.set_xlabel(\"epoch\")\n",
    "ax.set_ylabel(\"average error\")\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2.1211e-04, 1.0060e-01, 8.9918e-01], grad_fn=<SoftmaxBackward>)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predict the class probabilities for one row of the test set.\n",
    "TEST_SAMPLE_IDX = 21  # which test row to inspect\n",
    "\n",
    "# Inference only: no_grad() skips building the autograd graph, so the result is\n",
    "# a plain tensor without a grad_fn.\n",
    "with torch.no_grad():\n",
    "    y_pred_test = model(test_data[TEST_SAMPLE_IDX, 0:4])\n",
    "y_pred_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.2        0.25980762 0.08291562 0.35707142]\n",
      "[0.31124749 0.18027756 0.16393596 0.25860201]\n",
      "[0.3   0.35  0.425 0.45 ]\n",
      "[0.425 0.45  0.225 0.425]\n"
     ]
    }
   ],
   "source": [
    "# Scratch check of how `axis` works for np.std / np.mean on a 4x4 array:\n",
    "# axis=0 reduces down the rows (one value per column), axis=1 reduces across\n",
    "# the columns (one value per row).\n",
    "a = np.array([[0.1, 0.2, 0.5, 0.9],[0.5, 0.2, 0.4, 0.7],[0.1, 0.2, 0.5, 0.1],[0.5, 0.8, 0.3, 0.1]])\n",
    "for reduce_fn in (np.std, np.mean):\n",
    "    for axis in (0, 1):\n",
    "        print(reduce_fn(a, axis=axis))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
